import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def plot_heatmap(my_df, fig_name, y_labels=None):
    """Draw, save and show a heatmap of log(TPM+1) expression values.

    Every row of ``my_df`` becomes a heatmap row and every column a heatmap
    column. The colour bar and x axis are labelled for voltage-gated ion
    channel gene expression.

    Args:
        my_df: DataFrame of numeric values passed straight to
            ``seaborn.heatmap``.
        fig_name: Output path handed to ``savefig`` (written at 300 dpi).
        y_labels: Optional replacement labels for the y-axis ticks.
    """
    # Apply the 18 pt base font to *this* figure only. Calling plt.rc(...)
    # after drawing (as before) left this figure untouched but silently
    # enlarged the fonts of every figure created afterwards.
    with plt.rc_context({"font.size": 18}):
        # A frame with exactly two rows (e.g. two group means) gets a short
        # figure; anything else gets a tall one. This counts ROWS of my_df.
        height = 4 if my_df.shape[0] == 2 else 20
        fig, ax = plt.subplots(figsize=(64, height))
        sns.heatmap(my_df, ax=ax, cbar_kws={"orientation": "vertical", "pad": 0.01})
        ax.collections[0].colorbar.set_label(
            "Gene expression level log(TPM+1)", labelpad=20, fontsize=14
        )

        if y_labels is not None:
            ax.set_yticklabels(y_labels)

        # Horizontal row labels, vertical column labels.
        ax.tick_params(axis="y", labelrotation=0)
        ax.tick_params(axis="x", labelrotation=90)
        ax.set_xlabel("Voltage gated ion channel genes", fontsize=14)

        fig.tight_layout()
        fig.savefig(fig_name, dpi=300)
        plt.show()
    # Release the (very large) figure so repeated calls do not pile up memory.
    plt.close(fig)


def plot_mean_genes_expression(combined_df, _expressed_genes, fig_name):
    """Bar-plot Spp vs Ecel mean expression, with error bars, for chosen genes.

    Args:
        combined_df: DataFrame with a ``"Gene"`` column plus ``Spp_Mean``,
            ``Spp_SE``, ``Ecel_Mean`` and ``Ecel_SE`` columns.
        _expressed_genes: Genes to plot. Rows whose ``"Gene"`` is not in this
            collection are dropped; bar order follows ``combined_df``.
        fig_name: Output path handed to ``savefig`` (written at 300 dpi).
    """
    filtered_df = combined_df[combined_df["Gene"].isin(_expressed_genes)]
    # Symmetric error values, one row per plotted column: shape (2, n_genes).
    yerr = filtered_df[["Spp_SE", "Ecel_SE"]].to_numpy().T

    # Draw into an explicit 14x5 axes. Previously plt.figure(...) was followed
    # by DataFrame.plot without ax=, so pandas opened a *second* figure: the
    # requested size was ignored and an empty figure was left behind.
    fig, axes = plt.subplots(figsize=(14, 5))
    filtered_df.plot(
        x="Gene",
        y=["Spp_Mean", "Ecel_Mean"],
        kind="bar",
        yerr=yerr,
        alpha=0.9,
        color=["darkorchid", "springgreen"],
        ax=axes,
    )
    axes.set_ylabel("Mean gene expression level log(TPM+1)", fontsize=12)
    axes.set_xlabel("Select voltage gated ion channel genes", fontsize=12)

    axes.tick_params(axis="y", labelsize=10)
    axes.tick_params(axis="x", labelsize=12)
    axes.set_xticklabels(axes.get_xticklabels(), rotation=45, fontsize=12, ha="right")

    # Open-frame look: hide the top and right borders.
    axes.spines.top.set_visible(False)
    axes.spines.right.set_visible(False)

    axes.legend(["Spp", "Ecel"], fontsize=12)
    fig.tight_layout()
    fig.savefig(fig_name, dpi=300)
    plt.show()
    plt.close(fig)


# Patch-seq sample identifiers per subtype. Each ID is a cell column label of
# the expression table loaded in the __main__ block, where these lists are
# used with `.loc[...]` and `columns.isin(...)` to select cells.
# NOTE: keep these as lists, not tuples -- pandas `.loc[some_tuple]` is read
# as a (row, column) key rather than a list of row labels.
# Cells plotted under the "Spp" label.
Spp_exp = [
    "patchseq_170204_10_S213",
    "patchseq_170804_740_A5_S5_L004",
    "patchseq_170804_743_G5_S68_L004",
    "patchseq_170804_745_B6_S17_L004",
    "patchseq_170804_747_E6_S49_L004",
    "patchseq_170804_323_H7_S79_L004",
    "patchseq_170804_324_A8_S8_L004",
    "patchseq_170804_326_C8_S30_L004",
    "patchseq_170804_333_B9_S20_L004",
    "patchseq_170804_336_E9_S52_L004",
    "patchseq_170804_338_G9_S72_L004",
    "patchseq_170804_339_H9_S81_L004",
    "patchseq_170204_13_S249",
    "patchseq_170204_14_S261",
    "patchseq_170204_20_S238",
    "patchseq_170204_29_S251",
    "patchseq_170204_31_S275",
    "patchseq_170204_32_S287",
    "patchseq_170204_33_S204",
    "patchseq_170204_3_S224",
    "patchseq_170204_7_S272",
    "patchseq_170204_8_S284",
    "patchseq_170204_9_S201",
]

# Cells plotted under the "Ecel" label.
Ecel_exp = [
    "patchseq_170204_17_S202",
    "patchseq_170204_4_S236",
    "patchseq_170204_2_S212",
    "patchseq_170804_M456_D5_S38_L004",
    "patchseq_170804_M457_F5_S58_L004",
    "patchseq_170804_M458_H5_S78_L004",
    "patchseq_170804_748_G6_S69_L004",
    "patchseq_170804_316_A7_S7_L004",
    "patchseq_170804_317_B7_S18_L004",
    "patchseq_170804_cell760_C7_S29_L004",
    "patchseq_170804_cell762_F7_S60_L004",
    "patchseq_170804_334_C9_S31_L004",
    "patchseq_170804_337_F9_S62_L004",
    "patchseq_170204_18_S214",
    "patchseq_170204_22_S262",
    "patchseq_170204_23_S274",
    "patchseq_170204_24_S286",
    "patchseq_170204_26_S215",
    "patchseq_170204_28_S239",
    "patchseq_170204_5_S248",
    "patchseq_170804_330_G8_S71_L004",
]
# Genes shown in the strip/box plot and in the mean-expression bar chart;
# each must be a value in the table's gene-name column.
genes = [
    "Cacna1g",
    "Cacna1h",
    "Cacna1i",
    "Kcnn1",
    "Kcnn2",
]

if __name__ == "__main__":
    # NOTE: "epression" is misspelled but must match the file on disk.
    # NOTE(review): pd.read_pickle can execute arbitrary code -- only load
    # trusted files.
    expr_path = "Voltage_gated_ions_patchseq_genes_epression_levels.pkl"

    # Load the table once (it used to be read twice). Its first column holds
    # the gene names; it is renamed to "Gene" and used as the index.
    raw_df = pd.read_pickle(expr_path)
    raw_df = raw_df.rename(columns={raw_df.columns[0]: "Gene"})
    expr_df = raw_df.set_index("Gene")  # genes x cells

    # ---- Strip + box plot of the selected genes, per cell type ----------
    selected_df = expr_df.loc[genes].T  # cells x selected genes
    print(selected_df)
    # melt() names its variable column after the column axis, i.e. "Gene".
    df_spp = selected_df.loc[Spp_exp].melt()
    df_spp["cell_type"] = "spp"
    df_ecel = selected_df.loc[Ecel_exp].melt()
    df_ecel["cell_type"] = "ecel"
    plot_df = pd.concat([df_ecel, df_spp])

    fig, ax = plt.subplots(figsize=(5, 3))
    # legend=False: the boxplot below already draws the hue legend; letting
    # both plots add one duplicated every entry.
    sns.stripplot(
        data=plot_df,
        x="Gene",
        y="value",
        hue="cell_type",
        dodge=True,
        legend=False,
        ax=ax,
    )
    sns.boxplot(data=plot_df, x="Gene", y="value", hue="cell_type", fill=False, ax=ax)
    fig.tight_layout()
    fig.savefig("boxplot.pdf")
    # Close it so this figure does not pop up again at the next plt.show().
    plt.close(fig)

    # ---- Per-cell heatmap of all genes: Spp cells, then Ecel cells ------
    # Cell order follows the table's column order, not the ID lists.
    Spp_df = expr_df.loc[:, expr_df.columns.isin(Spp_exp)].copy()
    Spp_df.columns = [f"Spp1_cell{i}" for i in range(1, Spp_df.shape[1] + 1)]
    Ecel_df = expr_df.loc[:, expr_df.columns.isin(Ecel_exp)].copy()
    Ecel_df.columns = [f"Ecel1_cell{i}" for i in range(1, Ecel_df.shape[1] + 1)]

    Spp_and_Ecel = pd.concat([Spp_df, Ecel_df], axis=1)
    print(Spp_and_Ecel)
    plot_heatmap(Spp_and_Ecel.T, "AllVoltageGatedChannels_gene_exp_Spp_and_Ecel.pdf")

    # ---- Per-gene group summaries ---------------------------------------
    # The *_SE columns hold the standard error of the mean (std / sqrt(n)).
    # NOTE(review): they previously held the plain standard deviation
    # (.std) despite the name; switch back to .std if SD error bars are
    # wanted, and rename the columns accordingly.
    df = pd.DataFrame(index=Spp_df.index)
    df["Spp_Mean"] = Spp_df.mean(axis=1)
    df["Spp_SE"] = Spp_df.sem(axis=1)
    df["Ecel_Mean"] = Ecel_df.mean(axis=1)
    df["Ecel_SE"] = Ecel_df.sem(axis=1)
    print(df.index.to_list())
    print(df)

    plot_mean_genes_expression(
        df.reset_index(), genes, "SelectedChannels_mean_gene_exp_Spp_and_Ecel.pdf"
    )

    # ---- Heatmap of the two group means, largest |Spp - Ecel| first -----
    # Reuse the means computed above instead of recomputing them.
    ranked = (
        df[["Spp_Mean", "Ecel_Mean"]]
        .rename(columns={"Spp_Mean": "Spp mean", "Ecel_Mean": "Ecel mean"})
        .assign(diff=lambda d: (d["Spp mean"] - d["Ecel mean"]).abs())
        .sort_values("diff", ascending=False)
    )
    plot_heatmap(
        ranked[["Spp mean", "Ecel mean"]].T,
        "AllVoltageGatedChannels_MEAN_gene_exp_Spp_and_Ecel.pdf",
    )
